import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable

# Load the raw records as integer arrays.
# NOTE(review): train and test are both read from the same 'datafile' path;
# confirm the intended test-split file.
train = pd.read_csv('datafile', engine='python', encoding='latin-1', delimiter=';')
train = np.array(train, dtype='int')
test = pd.read_csv('datafile', engine='python', encoding='latin-1', delimiter=';')
test = np.array(test, dtype='int')
# convert() uses column 2 as the 1-based term id and column 1 as the 1-based
# country id, so the largest id seen in either split is the count. Take the
# max of each split separately: stacking two columns of different lengths
# into one ndarray (as np.amax([a, b]) does) fails for ragged inputs.
number_terms = int(max(train[:, 2].max(), test[:, 2].max()))
number_countries = int(max(train[:, 1].max(), test[:, 1].max()))

def convert(data, n_terms=None, n_countries=None):
    """Pivot record rows into a dense term-by-country sales table.

    Each row of ``data`` is read as: column 1 = country id, column 2 = term id,
    column 3 = sales value, with ids starting at 1.

    Args:
        data: 2-D array of integer records.
        n_terms: number of output rows (terms 1..n_terms). Defaults to the
            module-level ``number_terms``.
        n_countries: length of each output row. Defaults to the module-level
            ``number_countries``.

    Returns:
        A list of ``n_terms`` lists of floats, each of length ``n_countries``.
        Entry ``[t - 1][c - 1]`` holds the sales value recorded for term ``t``
        in country ``c``; entries with no record are 0.0. Records whose term
        id is outside 1..n_terms are ignored.
    """
    if n_terms is None:
        n_terms = number_terms
    if n_countries is None:
        n_countries = number_countries
    # Slice the columns once instead of on every loop iteration.
    country_ids = data[:, 1]
    term_ids = data[:, 2]
    sales_values = data[:, 3]
    new_data = []
    for term in range(1, n_terms + 1):
        in_term = term_ids == term
        sales = np.zeros(n_countries)
        # Ids are 1-based, output columns are 0-based.
        sales[country_ids[in_term] - 1] = sales_values[in_term]
        new_data.append(sales.tolist())
    return new_data
# Pivot each split into its dense term-by-country table and wrap it as a
# float32 tensor for the RBM.
train, test = torch.FloatTensor(convert(train)), torch.FloatTensor(convert(test))

class RBM():
    """Restricted Boltzmann machine with Bernoulli-sampled hidden and visible units.

    Attributes:
        W: weight matrix, shape (nh, nv).
        a: hidden bias, shape (1, nh).
        b: visible bias, shape (1, nv).
        lr: step size applied to each update in ``train``.
        momentum: decay factor for the velocity buffers.
        vW, va, vb: velocity (momentum) buffers for W, a and b, zero-initialized.
    """

    def __init__(self, nv, nh, lr, momentum):
        # Parameters are drawn in the order W, a, b (relevant under a fixed seed).
        self.W = torch.randn(nh, nv)
        self.a = torch.randn(1, nh)
        self.b = torch.randn(1, nv)
        self.lr, self.momentum = lr, momentum
        self.vW, self.va, self.vb = map(torch.zeros_like, (self.W, self.a, self.b))

    def temp_h(self, x):
        """Return (P(h=1 | v=x), Bernoulli sample of h) for a 2-D batch ``x``."""
        prob_hidden = torch.sigmoid(torch.mm(x, self.W.t()) + self.a)
        return prob_hidden, torch.bernoulli(prob_hidden)

    def temp_v(self, y):
        """Return (P(v=1 | h=y), Bernoulli sample of v) for a 2-D batch ``y``."""
        prob_visible = torch.sigmoid(torch.mm(y, self.W) + self.b)
        return prob_visible, torch.bernoulli(prob_visible)

    def train(self, v0, vk, ph0, phk):
        """Apply one momentum step from positive (v0, ph0) and negative (vk, phk) statistics.

        Gradients are summed over the batch dimension. Each velocity becomes
        ``momentum * velocity + lr * gradient`` and is then added in place to
        its parameter.
        """
        positive = torch.mm(v0.t(), ph0)
        negative = torch.mm(vk.t(), phk)
        grad_W = (positive - negative).t()
        grad_b = (v0 - vk).sum(dim=0)
        grad_a = (ph0 - phk).sum(dim=0)

        self.vW = self.momentum * self.vW + self.lr * grad_W
        self.va = self.momentum * self.va + self.lr * grad_a
        self.vb = self.momentum * self.vb + self.lr * grad_b

        self.W.add_(self.vW)
        self.a.add_(self.va)
        self.b.add_(self.vb)

# Model / training hyperparameters. These are all referenced below (RBM
# construction and the training loop), so they must be defined; leaving them
# commented out raises NameError.
nv = len(train[0])   # visible units: one per column of the converted table
nh = 100             # hidden units
lr = 0.001
momentum = 0.98
batch_size = 3       # terms per mini-batch in the training loop

rbm = RBM(nv, nh, lr, momentum)

nb_epoch = 490
for epoch in range(1, nb_epoch + 1):
    train_loss = 0
    s = 0.
    # Visit every term in mini-batches of `batch_size`; the final batch may be
    # shorter. (A bound of `number_terms - batch_size` would skip the last
    # batch entirely.)
    for id_term in range(0, number_terms, batch_size):
        vk = train[id_term:id_term + batch_size]
        v0 = train[id_term:id_term + batch_size]
        ph0, _ = rbm.temp_h(v0)
        # 10 rounds of alternating sampling h ~ P(h|v), v ~ P(v|h), starting
        # from the data.
        for _ in range(10):
            hk = rbm.temp_h(vk)[1]
            vk = rbm.temp_v(hk)[1]
            # Clamp entries that are negative in the input back to their input
            # values after every step. NOTE(review): convert() fills missing
            # entries with 0, so this only has an effect if the data itself
            # contains negative values — confirm that is intended.
            vk[v0 < 0] = v0[v0 < 0]
        phk, _ = rbm.temp_h(vk)
        rbm.train(v0, vk, ph0, phk)
        # Mean absolute reconstruction error over the non-negative entries.
        train_loss += torch.mean(torch.abs(v0[v0 >= 0] - vk[v0 >= 0]))
        s += 1.
    print(str(train_loss / s))
  
# Evaluate: reconstruct each term's training row with one h/v sampling pass and
# compare it against the same term's test row.
test_loss = 0
s = 0.
for id_term in range(number_terms):
    v = train[id_term:id_term + 1]   # model input: this term's training row
    vt = test[id_term:id_term + 1]   # target: this term's test row
    mask = vt >= 0
    if mask.any():
        _, h = rbm.temp_h(v)
        _, v_rec = rbm.temp_v(h)
        # Mean absolute error over the masked (>= 0) entries — the same formula
        # the training loss uses.
        test_loss += torch.mean(torch.abs(vt[mask] - v_rec[mask]))
        s += 1.
# Avoid 0/0 when no test row has any entry >= 0.
if s:
    print(str(test_loss / s))
else:
    print('no test terms to evaluate')




